//
// Created by Xing on 2022/3/10.
//

#ifndef CAFFE_INSTANCE_NORM_LAYER_HPP
#define CAFFE_INSTANCE_NORM_LAYER_HPP

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {
    /**
     * @brief Instance Normalization layer (CPU forward only).
     *
     * Normalizes each (instance, channel) slice of the input independently:
     * statistics are computed per n-by-c pair over the spatial extent S = H*W,
     * unlike batch norm which pools statistics across the batch.
     *
     * GPU forward and both backward passes are declared but NOT_IMPLEMENTED.
     */
    template<typename Dtype>
    class InstanceNormLayer : public Layer<Dtype> {
    public:
        explicit InstanceNormLayer(const LayerParameter &param)
                : Layer<Dtype>(param) {}
        virtual void LayerSetUp(const vector<Blob<Dtype>*> &bottom,
                                const vector<Blob<Dtype>*> &top);
        virtual void Reshape(const vector<Blob<Dtype>*> &bottom,
                             const vector<Blob<Dtype>*> &top);
        virtual inline const char *type() const { return "InstanceNorm"; }
        virtual inline int ExactNumBottomBlobs() const { return 1; }
        virtual inline int ExactNumTopBlobs() const { return 1; }

    protected:
        virtual void Forward_cpu(const vector<Blob<Dtype>*> &bottom,
                                 const vector<Blob<Dtype>*> &top);
        virtual void Forward_gpu(const vector<Blob<Dtype>*> &bottom,
                                 const vector<Blob<Dtype>*> &top) { NOT_IMPLEMENTED; }
        virtual void Backward_cpu(const vector<Blob<Dtype>*> &top,
                                  const vector<bool> &propagate_down,
                                  const vector<Blob<Dtype>*> &bottom) { NOT_IMPLEMENTED; }
        virtual void Backward_gpu(const vector<Blob<Dtype>*> &top,
                                  const vector<bool> &propagate_down,
                                  const vector<Blob<Dtype>*> &bottom) { NOT_IMPLEMENTED; }
    private:
        // Broadcast a per-(instance,channel) vector x[N*C] across the spatial
        // extent: y[n,c,s] = x[n,c] for all s in [0, S).
        // Implemented as a rank-1 GEMM: (N*C x 1) * (1 x S) ones row.
        void multicast_cpu(int N, int C, int S, const Dtype *x, Dtype *y) {
            // K is the integer inner dimension (1 for a rank-1 product);
            // previously Dtype(1.0f) was passed here, relying on an implicit
            // float->int narrowing conversion.
            caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N * C, S, 1,
                                  Dtype(1.0f), x, ones_HW_.cpu_data(), Dtype(0.0f), y);
        }

        // Broadcast a per-channel vector x[C] across both the batch and the
        // spatial extent: y[n,c,s] = x[c]. Done in two rank-1 GEMMs:
        // first ones_N (N x 1) * x (1 x C) -> temp_NC_, then temp_NC_ across S.
        void multicast_cpu_v2(int N, int C, int S, const Dtype *x, Dtype *y) {
            caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N, C, 1, Dtype(1.0f),
                                  ones_N_.cpu_data(), x, Dtype(0.0f), temp_NC_.mutable_cpu_data());

            caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N * C, S, 1,
                                  Dtype(1.0f), temp_NC_.cpu_data(), ones_HW_.cpu_data(), Dtype(0.0f), y);
        }

        // Spatial reduction: y[n,c] = sum_s x[n,c,s], via GEMV of the
        // (N*C x S) matrix against a ones vector of length S.
        void compute_sum_per_channel_cpu(int N, int C, int S, const Dtype *x, Dtype *y) {
            caffe_cpu_gemv<Dtype>(CblasNoTrans, N * C, S, Dtype(1.0f), x, ones_HW_.cpu_data(), Dtype(0.0f), y);
        }

        // Spatial mean: y[n,c] = (1/S) * sum_s x[n,c,s].
        void compute_mean_per_channel_cpu(int N, int C, int S, const Dtype *x, Dtype *y) {
            const Dtype F = Dtype(1) / S;  // explicit Dtype avoids double->float narrowing
            compute_sum_per_channel_cpu(N, C, S, x, y);
            caffe_cpu_scale(N * C, F, y, y);
        }


    private:
        Dtype eps_;          // variance-stabilizing epsilon
        int channels_;       // C of the bottom blob
        bool scale_bias_;    // whether learned scale/bias are applied
        Blob<Dtype> mean_, var_, inv_var_, x_norm_;
        // ones_* are constant broadcast helpers; temp_* are scratch buffers.
        Blob<Dtype> ones_N_, ones_HW_, ones_C_, temp_C_, temp_NC_, temp_NCHW_;
    };
}  // namespace caffe

#endif //CAFFE_INSTANCE_NORM_LAYER_HPP
